// Copyright (c) Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

use crate::execution_scheduler::balance_withdraw_scheduler::ScheduleResult;
use crate::execution_scheduler::balance_withdraw_scheduler::{
    BalanceSettlement, ScheduleStatus, TxBalanceWithdraw, mock_balance_read::MockBalanceRead,
    scheduler::BalanceWithdrawScheduler,
};
use futures::StreamExt;
use futures::stream::FuturesUnordered;
use mysten_metrics::monitored_mpsc::unbounded_channel;
use parking_lot::Mutex;
use rand::{Rng, seq::SliceRandom};
use std::{collections::BTreeMap, sync::Arc, time::Duration};
use sui_macros::sim_test;
use sui_types::{
    accumulator_root::AccumulatorObjId,
    base_types::{ObjectID, SequenceNumber},
    digests::TransactionDigest,
};
use tokio::sync::oneshot;
use tokio::time::error::Elapsed;
use tokio::time::timeout;
use tracing::{debug, info};

// Test harness pairing a `BalanceWithdrawScheduler` with the `MockBalanceRead` it reads
// balances from. Tests can drive the mock store and the scheduler together (via
// `settle_balance_changes`) or independently (by touching `mock_read` / `scheduler` directly).
#[derive(Clone)]
struct TestScheduler {
    // Mock balance store; the scheduler is constructed with its own clone of this `Arc`.
    mock_read: Arc<MockBalanceRead>,
    // The scheduler under test.
    scheduler: BalanceWithdrawScheduler,
}

impl TestScheduler {
    /// Creates a harness backed by a fresh mock store holding `init_balances` at
    /// `init_version`.
    fn new(init_version: SequenceNumber, init_balances: BTreeMap<ObjectID, u128>) -> Self {
        let mock_read = Arc::new(MockBalanceRead::new(init_version, init_balances));
        Self::new_with_mock_read(mock_read)
    }

    /// Creates a harness on top of an existing mock store.
    ///
    /// The scheduler starts from whatever `mock_read.cur_version()` reports at
    /// construction time, and shares the store through a cloned `Arc`.
    fn new_with_mock_read(mock_read: Arc<MockBalanceRead>) -> Self {
        let scheduler = BalanceWithdrawScheduler::new(mock_read.clone(), mock_read.cur_version());
        Self {
            mock_read,
            scheduler,
        }
    }

    /// Schedules `withdraws` at accumulator `version`, returning the receivers that
    /// resolve with each withdraw's scheduling result.
    fn schedule_withdraws(
        &self,
        version: SequenceNumber,
        withdraws: Vec<TxBalanceWithdraw>,
    ) -> FuturesUnordered<oneshot::Receiver<ScheduleResult>> {
        self.scheduler.schedule_withdraws(version, withdraws)
    }

    /// Applies `changes` (keyed by raw account id) first to the mock store and then to the
    /// scheduler, advancing both to `next_accumulator_version`.
    fn settle_balance_changes(
        &self,
        next_accumulator_version: SequenceNumber,
        changes: BTreeMap<ObjectID, i128>,
    ) {
        // We own `changes`, so consume it instead of iterating by reference.
        let balance_changes: BTreeMap<_, _> = changes
            .into_iter()
            .map(|(id, value)| (AccumulatorObjId::new_unchecked(id), value))
            .collect();
        // The mock store needs its own copy; the original map is then moved into the
        // settlement (previously it was cloned a second time for no reason).
        self.mock_read
            .settle_balance_changes(balance_changes.clone(), next_accumulator_version);
        self.scheduler.settle_balances(BalanceSettlement {
            next_accumulator_version,
            balance_changes,
        });
    }

    /// Polls every 10ms until the scheduler's current accumulator version is at least
    /// `version`. There is no upper bound on how long this waits.
    async fn wait_for_accumulator_version(&self, version: SequenceNumber) {
        while self.scheduler.get_current_accumulator_version() < version {
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
    }
}

/// Drains every receiver and asserts the collected `tx_digest -> status` map equals
/// `expected_results`.
///
/// The drain-and-compare runs under a 3-second timeout. `Err(Elapsed)` means not every
/// receiver resolved in time, which tests use to assert that a withdraw is still blocked.
/// Panics if a sender is dropped without sending, or if the collected results differ.
async fn wait_for_results(
    receivers: FuturesUnordered<oneshot::Receiver<ScheduleResult>>,
    expected_results: BTreeMap<TransactionDigest, ScheduleStatus>,
) -> Result<(), Elapsed> {
    let drain_and_compare = async move {
        let observed: BTreeMap<_, _> = receivers
            .map(|received| {
                let outcome = received.unwrap();
                (outcome.tx_digest, outcome.status)
            })
            .collect()
            .await;
        assert_eq!(observed, expected_results);
    };
    timeout(Duration::from_secs(3), drain_and_compare).await
}

#[tokio::test]
async fn test_schedule_wait_for_settlement() {
    // A withdraw scheduled for a version the scheduler has not reached cannot be
    // scheduled until the corresponding settlement arrives. No settlement is ever sent
    // here, so the result never resolves and waiting on it must time out.
    let init_version = SequenceNumber::from_u64(0);
    let account = ObjectID::random();
    let test = TestScheduler::new(init_version, BTreeMap::from([(account, 100)]));

    let withdraw = TxBalanceWithdraw {
        tx_digest: TransactionDigest::random(),
        reservations: BTreeMap::from([(AccumulatorObjId::new_unchecked(account), 200)]),
    };
    let expected = BTreeMap::from([(withdraw.tx_digest, ScheduleStatus::SufficientBalance)]);

    let receivers = test.schedule_withdraws(init_version.next(), vec![withdraw]);
    let outcome = wait_for_results(receivers, expected).await;
    assert!(outcome.is_err());
}

#[tokio::test]
async fn test_schedules_and_settles() {
    let v0 = SequenceNumber::from_u64(0);
    let account = ObjectID::random();
    let test = TestScheduler::new(v0, BTreeMap::from([(account, 100)]));

    // Each step optionally settles a balance delta (advancing the accumulator version by
    // one) and then schedules a single 60-unit withdraw at the current version:
    //   step 0: balance 100 at v0               -> sufficient
    //   step 1: settle -60 (100 -> 40), v0 -> v1 -> insufficient
    //   step 2: settle +20 (40 -> 60),  v1 -> v2 -> sufficient
    let steps = [
        (None, ScheduleStatus::SufficientBalance),
        (Some(-60), ScheduleStatus::InsufficientBalance),
        (Some(20), ScheduleStatus::SufficientBalance),
    ];

    let mut version = v0;
    for (delta, expected_status) in steps {
        if let Some(delta) = delta {
            version = version.next();
            test.settle_balance_changes(version, BTreeMap::from([(account, delta)]));
        }

        let withdraw = TxBalanceWithdraw {
            tx_digest: TransactionDigest::random(),
            reservations: BTreeMap::from([(AccumulatorObjId::new_unchecked(account), 60)]),
        };
        let digest = withdraw.tx_digest;
        let receivers = test.schedule_withdraws(version, vec![withdraw]);
        wait_for_results(receivers, BTreeMap::from([(digest, expected_status)]))
            .await
            .unwrap();
    }
}

#[tokio::test]
async fn test_already_executed() {
    let init_version = SequenceNumber::from_u64(0);
    let next_version = init_version.next();
    let account1 = ObjectID::random();
    let account2 = ObjectID::random();
    let test = TestScheduler::new(
        init_version,
        BTreeMap::from([(account1, 100), (account2, 200)]),
    );

    // Advance the accumulator version, then wait until the scheduler actually reports it.
    // The previous fixed 10ms sleep raced with the scheduler processing the settlement
    // and could make this test flaky on a slow or loaded machine.
    test.settle_balance_changes(next_version, BTreeMap::new());
    test.wait_for_accumulator_version(next_version).await;

    // Withdraws targeting the now-stale `init_version` are expected to be skipped.
    let withdraw1 = TxBalanceWithdraw {
        tx_digest: TransactionDigest::random(),
        reservations: BTreeMap::from([(AccumulatorObjId::new_unchecked(account1), 50)]),
    };
    let withdraw2 = TxBalanceWithdraw {
        tx_digest: TransactionDigest::random(),
        reservations: BTreeMap::from([(AccumulatorObjId::new_unchecked(account2), 100)]),
    };

    let receivers =
        test.schedule_withdraws(init_version, vec![withdraw1.clone(), withdraw2.clone()]);
    wait_for_results(
        receivers,
        BTreeMap::from([
            (withdraw1.tx_digest, ScheduleStatus::SkipSchedule),
            (withdraw2.tx_digest, ScheduleStatus::SkipSchedule),
        ]),
    )
    .await
    .unwrap();
}

#[tokio::test]
async fn test_multiple_withdraws_same_version() {
    // With a balance of 90, the second 50-unit withdraw fails. A failed withdraw does not
    // reserve anything, so the third (40-unit) withdraw still fits into what is left.
    let init_version = SequenceNumber::from_u64(0);
    let account = ObjectID::random();
    let test = TestScheduler::new(init_version, BTreeMap::from([(account, 90)]));
    let account_id = AccumulatorObjId::new_unchecked(account);

    // (reserved amount, expected status) for each withdraw, in scheduling order.
    let plan = [
        (50, ScheduleStatus::SufficientBalance),
        (50, ScheduleStatus::InsufficientBalance),
        (40, ScheduleStatus::SufficientBalance),
    ];

    let (withdraws, expected): (Vec<_>, BTreeMap<_, _>) = plan
        .into_iter()
        .map(|(amount, status)| {
            let withdraw = TxBalanceWithdraw {
                tx_digest: TransactionDigest::random(),
                reservations: BTreeMap::from([(account_id, amount)]),
            };
            let digest = withdraw.tx_digest;
            (withdraw, (digest, status))
        })
        .unzip();

    let receivers = test.schedule_withdraws(init_version, withdraws);
    wait_for_results(receivers, expected).await.unwrap();
}

#[tokio::test]
async fn test_multiple_withdraws_multiple_accounts_same_version() {
    let init_version = SequenceNumber::from_u64(0);
    let account1 = ObjectID::random();
    let account2 = ObjectID::random();
    let test = TestScheduler::new(
        init_version,
        BTreeMap::from([(account1, 100), (account2, 100)]),
    );
    let (id1, id2) = (
        AccumulatorObjId::new_unchecked(account1),
        AccumulatorObjId::new_unchecked(account2),
    );

    // (reservations, expected status) for each withdraw, in scheduling order.
    let plan = vec![
        (
            BTreeMap::from([(id1, 100), (id2, 200)]),
            ScheduleStatus::InsufficientBalance,
        ),
        (BTreeMap::from([(id1, 1)]), ScheduleStatus::InsufficientBalance),
        (BTreeMap::from([(id2, 100)]), ScheduleStatus::SufficientBalance),
    ];

    let mut withdraws = Vec::with_capacity(plan.len());
    let mut expected = BTreeMap::new();
    for (reservations, status) in plan {
        let withdraw = TxBalanceWithdraw {
            tx_digest: TransactionDigest::random(),
            reservations,
        };
        expected.insert(withdraw.tx_digest, status);
        withdraws.push(withdraw);
    }

    let receivers = test.schedule_withdraws(init_version, withdraws);
    wait_for_results(receivers, expected).await.unwrap();
}

#[tokio::test]
async fn test_withdraw_already_settled_account_object() {
    let v0 = SequenceNumber::from_u64(0);
    let v1 = v0.next();
    let v2 = v1.next();
    let account = ObjectID::random();
    let account_id = AccumulatorObjId::new_unchecked(account);

    // The mock store already holds the account object at `v1`, modelling a store that has
    // observed settlement output before the scheduler processed the matching settlement.
    // NOTE(review): `new_with_mock_read` seeds the scheduler with `mock_read.cur_version()`,
    // so the scheduler presumably also starts at `v1` here — confirm against MockBalanceRead.
    let mock_read = Arc::new(MockBalanceRead::new(
        v1,
        BTreeMap::from([(account, 100u128)]),
    ));
    let test = TestScheduler::new_with_mock_read(mock_read.clone());

    let withdraw = TxBalanceWithdraw {
        tx_digest: TransactionDigest::random(),
        reservations: BTreeMap::from([(account_id, 60)]),
    };
    let digest = withdraw.tx_digest;

    // A withdraw at `v0` is older than the store's view and is expected to be skipped.
    let stale = test.schedule_withdraws(v0, vec![withdraw.clone()]);
    wait_for_results(
        stale,
        BTreeMap::from([(digest, ScheduleStatus::SkipSchedule)]),
    )
    .await
    .unwrap();

    // Move only the mock store's object to `v2` (zero balance change), bypassing the
    // scheduler. Scheduling the same withdraw at `v1` is expected to see enough balance.
    mock_read.settle_balance_changes(BTreeMap::from([(account_id, 0)]), v2);
    let current = test.schedule_withdraws(v1, vec![withdraw]);
    wait_for_results(
        current,
        BTreeMap::from([(digest, ScheduleStatus::SufficientBalance)]),
    )
    .await
    .unwrap();
}

#[tokio::test]
async fn test_settle_just_updated_account_object() {
    let v0 = SequenceNumber::from_u64(0);
    let v1 = v0.next();
    let v2 = v1.next();
    let account1 = ObjectID::random();
    let account2 = ObjectID::random();
    let id1 = AccumulatorObjId::new_unchecked(account1);
    let id2 = AccumulatorObjId::new_unchecked(account2);

    let mock_read = Arc::new(MockBalanceRead::new(
        v0,
        BTreeMap::from([(account1, 100u128), (account2, 100u128)]),
    ));
    let test = TestScheduler::new_with_mock_read(mock_read.clone());

    // Bump both account objects in the mock store to `v1` without telling the scheduler.
    mock_read.settle_balance_changes(BTreeMap::from([(id1, 0), (id2, 0)]), v1);

    // Each withdraw reserves 150 against a balance of 100 (at object version `v1`) and
    // targets `v2`. The eager scheduler is expected to hold them as pending reservations.
    let make_withdraw = |id| TxBalanceWithdraw {
        tx_digest: TransactionDigest::random(),
        reservations: BTreeMap::from([(id, 150)]),
    };
    let withdraw1 = make_withdraw(id1);
    let withdraw2 = make_withdraw(id2);
    let pending1 = test.schedule_withdraws(v2, vec![withdraw1.clone()]);
    let pending2 = test.schedule_withdraws(v2, vec![withdraw2.clone()]);

    // Advance only the scheduler to `v1` (the mock store is already there). Since the
    // object is still at `v1`, `withdraw1` must remain unresolved and the wait times out.
    test.scheduler.settle_balances(BalanceSettlement {
        next_accumulator_version: v1,
        balance_changes: BTreeMap::new(),
    });
    test.wait_for_accumulator_version(v1).await;
    let still_pending = wait_for_results(
        pending1,
        BTreeMap::from([(withdraw1.tx_digest, ScheduleStatus::InsufficientBalance)]),
    )
    .await;
    assert!(still_pending.is_err());

    // Settling `v2` in both the mock store and the scheduler triggers processing of the
    // pending withdraw, which is then reported as insufficient.
    test.settle_balance_changes(v2, BTreeMap::new());
    test.wait_for_accumulator_version(v2).await;

    wait_for_results(
        pending2,
        BTreeMap::from([(withdraw2.tx_digest, ScheduleStatus::InsufficientBalance)]),
    )
    .await
    .unwrap();
}

#[tokio::test]
async fn test_withdraw_settle_and_deleted_account() {
    let v0 = SequenceNumber::from_u64(0);
    let v1 = v0.next();
    let account = ObjectID::random();
    let account_id = AccumulatorObjId::new_unchecked(account);

    let mock_read = Arc::new(MockBalanceRead::new(
        v0,
        BTreeMap::from([(account, 100u128)]),
    ));
    let test = TestScheduler::new_with_mock_read(mock_read.clone());

    // Settle -100 in the mock store only, so the scheduler still believes it is at `v0`
    // while the store has moved to `v1` with a zero balance. The zero balance is meant to
    // make the mock delete the account object (presumably; see MockBalanceRead).
    mock_read.settle_balance_changes(BTreeMap::from([(account_id, -100)]), v1);

    // Withdrawing the full pre-settlement balance at `v0` is expected to succeed.
    let withdraw = TxBalanceWithdraw {
        tx_digest: TransactionDigest::random(),
        reservations: BTreeMap::from([(account_id, 100)]),
    };
    let digest = withdraw.tx_digest;

    let receivers = test.schedule_withdraws(v0, vec![withdraw]);
    wait_for_results(
        receivers,
        BTreeMap::from([(digest, ScheduleStatus::SufficientBalance)]),
    )
    .await
    .unwrap();
}

/// Randomly generated workload shared by every run of the stress test.
struct StressTestEnv {
    /// Every account id that withdraws and settlements may touch.
    accounts: Vec<ObjectID>,
    /// Initial balances; only a random subset of `accounts` has an entry.
    init_balances: BTreeMap<ObjectID, u128>,
    /// Withdraw batches in scheduling order, each tagged with the accumulator version
    /// it is scheduled at.
    withdraws: Vec<(SequenceNumber, Vec<TxBalanceWithdraw>)>,
}

impl StressTestEnv {
    fn new(num_accounts: usize, num_transactions: usize) -> Self {
        let mut version = SequenceNumber::from_u64(0);
        let accounts = (0..num_accounts)
            .map(|_| ObjectID::random())
            .collect::<Vec<_>>();
        let mut rng = rand::thread_rng();
        let init_balances = accounts
            .iter()
            .filter_map(|account_id| {
                if rng.gen_bool(0.7) {
                    Some((*account_id, rng.gen_range(0..20)))
                } else {
                    None
                }
            })
            .collect::<BTreeMap<_, _>>();
        tracing::debug!("Init balances: {:?}", init_balances);

        let mut withdraws = Vec::new();
        let mut cur_reservations = Vec::new();
        for idx in 0..num_transactions {
            let num_reservation_accounts = rng.gen_range(1..3);
            let account_ids = accounts
                .choose_multiple(&mut rng, num_reservation_accounts)
                .cloned()
                .collect::<Vec<_>>();
            let reservations = account_ids
                .iter()
                .map(|account_id| {
                    (
                        AccumulatorObjId::new_unchecked(*account_id),
                        rng.gen_range(1..10),
                    )
                })
                .collect::<BTreeMap<_, _>>();
            cur_reservations.push(TxBalanceWithdraw {
                tx_digest: TransactionDigest::random(),
                reservations,
            });
            // Every now and then we group all withdraws into a commit, which we would
            // generate a settlement for later.
            if rng.gen_bool(0.2) || idx == num_transactions - 1 {
                withdraws.push((version, std::mem::take(&mut cur_reservations)));
                version = version.next();
            }
        }

        Self {
            init_balances,
            accounts,
            withdraws,
        }
    }
}

// Stress test for scheduler determinism.
//
// A single random workload (`StressTestEnv`) is generated up front and replayed 50 times,
// each time against a fresh `TestScheduler`. In every run a spawned task applies one
// settlement per withdraw batch, after all results for that batch have been observed.
// Run 0 generates those settlements randomly (bounded by what was actually reserved), and
// later runs replay the recorded settlements. Every run must produce exactly the same
// status for every transaction digest as run 0.
#[sim_test]
async fn balance_withdraw_scheduler_stress_test() {
    telemetry_subscribers::init_for_testing();

    let num_accounts = 5;
    let num_transactions = 20000;

    info!(
        "Running stress test with num_accounts: {:?}, num_transactions: {:?}",
        num_accounts, num_transactions
    );

    let StressTestEnv {
        init_balances,
        accounts,
        withdraws,
    } = StressTestEnv::new(num_accounts, num_transactions);

    info!("Starting stress test");

    // Repeat the process many times to ensure deterministic results.
    // `expected_results` holds run 0's statuses; `settlements` records run 0's randomly
    // generated settlements so later runs can replay them verbatim.
    let mut expected_results: Option<BTreeMap<TransactionDigest, ScheduleStatus>> = None;
    let settlements = Arc::new(Mutex::new(Vec::new()));
    for test_run in 0..50 {
        debug!("Running test instance {:?}", test_run);
        let init_balances = init_balances.clone();
        let accounts = accounts.clone();
        let withdraws = withdraws.clone();
        let settlements = settlements.clone();

        let results = tokio::time::timeout(
            Duration::from_secs(30),
            async {
                // `version` is `Copy`; the settlement task below gets its own copy via `async move`.
                let mut version = SequenceNumber::from_u64(0);
                let test = TestScheduler::new(
                    version,
                    init_balances,
                );

                // Spawn a separate tokio task to run all settlements on the scheduler.
                // It receives, per withdraw batch, the total amount reserved per account by
                // the withdraws that got `SufficientBalance`.
                let test_clone = test.clone();
                let (schedule_results_tx, mut schedule_results_rx) = unbounded_channel::<BTreeMap<AccumulatorObjId, u64>>("test");
                let settle_task = tokio::spawn(async move {
                    let mut idx = 0;
                    while let Some(reserved_amounts) = schedule_results_rx.recv().await {
                        if test_run == 0 {
                            // Only generate random settlements for the first test run.
                            // All future test runs should use the same settlements.
                            let mut rng = rand::thread_rng();
                            // NOTE(review): `0..accounts.len()` never selects every account
                            // and panics if `accounts` is empty — confirm this is intended.
                            let num_changes = rng.gen_range(0..accounts.len());
                            let balance_changes = accounts
                                .choose_multiple(&mut rng, num_changes)
                                .map(|account_id| {
                                    // The withdrawn amount is strictly less than what was
                                    // reserved on this account in this batch, so a settlement
                                    // never withdraws more than the scheduler approved.
                                    let withdraws = if let Some(reserved_amount) = reserved_amounts.get(&AccumulatorObjId::new_unchecked(*account_id)) {
                                        rng.gen_range(0..*reserved_amount) as i128
                                    } else {
                                        0
                                    };
                                    let deposits = rng.gen_range(0..10) as i128;
                                    let change = deposits - withdraws;
                                    (*account_id, change)
                                })
                                .collect::<BTreeMap<_, _>>();
                            settlements.lock().push(balance_changes);
                        }

                        // Apply the `idx`-th settlement, advancing both the mock store and the
                        // scheduler to the next accumulator version.
                        version = version.next();
                        test_clone.settle_balance_changes(version, settlements.lock()[idx].clone());
                        idx += 1;
                    }
                });

                // Schedule every batch up front; results resolve as settlements advance.
                let mut all_receivers = Vec::new();
                for (version, withdraws) in withdraws {
                    debug!("Test instance scheduling withdraws for version {:?}", version);
                    let receivers = test.schedule_withdraws(version, withdraws.clone());
                    all_receivers.push((version, receivers, withdraws));
                }

                // Wait for each batch in order, then report its reserved totals so the
                // settlement task can settle that batch.
                let mut results = BTreeMap::new();
                for (version, receivers, withdraws) in all_receivers {
                    debug!("Test instance waiting for results from version {:?}, receiver count: {}", version, receivers.len());
                    for result in receivers {
                        let result = result.await.unwrap();
                        debug!("Test instance received result for tx {:?} with status {:?} at version {:?}", result.tx_digest, result.status, version);
                        results.insert(result.tx_digest, result.status);
                    }
                    let mut reserved_amounts = BTreeMap::new();
                    for withdraw in withdraws {
                        if results.get(&withdraw.tx_digest) == Some(&ScheduleStatus::SufficientBalance) {
                            for (account_id, reservation) in withdraw.reservations {
                                *reserved_amounts.entry(account_id).or_insert(0) += reservation;
                            }
                        }
                    }
                    schedule_results_tx.send(reserved_amounts).unwrap();
                }
                // Drop the sender so that the settlement task can exit.
                drop(schedule_results_tx);

                // Make sure all settlements are processed.
                settle_task.await.unwrap();

                results
            })
            .await
            .expect("Task timed out after 30 seconds");
        // Run 0 defines the expected statuses; every later run must match them exactly.
        if let Some(expected_results) = &expected_results {
            assert_eq!(results.len(), expected_results.len());
            for (tx_digest, status) in results {
                assert_eq!(
                    &status,
                    expected_results.get(&tx_digest).unwrap(),
                    "Tx digest: {:?}",
                    tx_digest
                );
            }
        } else {
            expected_results = Some(results);
        }
    }
}
